
[Feature][Blackwell] Add SM120 T.float4_e2m1fn FP4 GEMM support.#2171

Open
TerminusAkivili wants to merge 4 commits into
tile-ai:mainfrom
TerminusAkivili:sm120-fp4-a8w4-clean-pr

Conversation

@TerminusAkivili (Contributor) commented May 8, 2026

Summary

This PR adds SM120 fragment-MMA GEMM support for T.float4_e2m1fn.
It covers plain FP4 GEMM and explicit mixed FP8/FP4 GEMM while keeping
the TileLang-facing API dtype-semantic.

Kernels continue to declare FP4 operands as T.float4_e2m1fn. Packed
byte storage is handled by lowering/codegen and by host-side example setup;
users do not have to model FP4 GEMM operands as uint8 tensors in TileLang
programs.
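
As an illustration of what the host-side setup does, the sketch below packs two 4-bit FP4 codes per byte in pure Python. The low-nibble-first order is an assumption for illustration; the authoritative packing is whatever the PR's examples and lowering define.

```python
def pack_fp4_codes(codes):
    """Pack an even-length sequence of 4-bit FP4 codes (0..15) into
    bytes, two codes per byte, low nibble first (assumed order)."""
    assert len(codes) % 2 == 0
    return bytes((codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
                 for i in range(0, len(codes), 2))

def unpack_fp4_codes(packed):
    """Inverse of pack_fp4_codes: recover the logical nibble stream."""
    out = []
    for b in packed:
        out.append(b & 0xF)        # first logical element: low nibble
        out.append((b >> 4) & 0xF)  # second logical element: high nibble
    return out
```

For example, `pack_fp4_codes([1, 2, 3, 4])` yields the two bytes `0x21, 0x43`, and unpacking restores the original code sequence.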

Supported GEMM combinations:

| A dtype | B dtype | Accumulator | SM120 path |
| --- | --- | --- | --- |
| T.float4_e2m1fn | T.float4_e2m1fn | T.float32 | FP4 x FP4 fragment MMA |
| T.float8_e4m3fn | T.float4_e2m1fn | T.float32 | A8W4 fragment MMA |
| T.float4_e2m1fn | T.float8_e4m3fn | T.float32 | W4A8 fragment MMA |

Design Goals

  • Preserve semantic TileLang signatures: FP4 tensors stay T.float4_e2m1fn
    at the language level.
  • Keep packed FP4 as a storage/lowering concern, not a public GEMM dtype
    workaround.
  • Add the SM120-specific hardware path without changing existing int4/uint4
    ldmatrix behavior.
  • Make the implementation reviewable by focusing this PR on SM120 FP4/A8W4
    GEMM.
  • Fail early for unsupported FP4/A8W4 MMA tile shapes instead of generating
    kernels that silently skip data.

Hardware Contracts

The SM120 FP4 path has three contracts that must line up across lowering,
layout, and template dispatch:

| Contract | Main implementation files |
| --- | --- |
| Shared FP4 operands are loaded with the SM120 b4x16_p64 ldmatrix | src/tl_templates/cuda/ldsm.h, src/backend/cuda/codegen/codegen_cuda.cc, src/backend/cuda/op/copy.cc, tilelang/cuda/intrinsics/layout/utils.py, tilelang/cuda/intrinsics/layout/mma_layout.py |
| Local FP4 fragments remain semantic FP4 objects for MMA operands | src/backend/cuda/codegen/codegen_cuda.cc, src/backend/cuda/codegen/codegen_cuda.h |
| GEMM emits SM120 m16n8k32 MMA for explicit FP4/FP8 dtype pairs | src/tl_templates/cuda/instruction/mma.h, src/tl_templates/cuda/gemm_mma.h, tilelang/cuda/intrinsics/macro/mma_macro_generator.py, tilelang/cuda/op/gemm/gemm_mma.py |

The key implementation detail is that SM120 b4x16_p64 consumes packed FP4
bytes from shared memory with a padded shared row layout. Global memory,
shared memory, and local fragments therefore cannot all use the same offset
model:

  • Global FP4 storage is packed by logical nibbles.
  • Shared FP4 storage uses the padded layout required by SM120 ldmatrix.
  • Local fragments keep the declared semantic names and types expected by
    the MMA lowering path.
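
The three offset models can be sketched in a few lines of Python. The padded-shared stride below is a hypothetical placeholder (the real padding is whatever the SM120 b4x16_p64 layout requires), and low-nibble-first packing is an assumption; the structural point is only that the three mappings are different functions of the same logical index.

```python
def fp4_global_byte_and_nibble(logical_index):
    """Packed global FP4 storage: two logical elements per byte.
    Returns (byte_offset, nibble); nibble 0 = low nibble (assumed order)."""
    return logical_index // 2, logical_index % 2

def fp4_padded_shared_offset(row, col, row_elems, pad_elems=16):
    """Hypothetical padded shared layout: each logical row of `row_elems`
    FP4 elements carries `pad_elems` elements of padding, so the shared
    row stride differs from the packed global model."""
    return row * (row_elems + pad_elems) + col

def fp4_local_fragment_index(lane_elem):
    """Local fragments keep plain semantic per-element indexing."""
    return lane_elem
```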

Implementation Details

Semantic FP4 Storage Model

This PR separates "what the TileLang program declares" from "how the FP4
payload is physically moved":

  • Language-level operands use T.float4_e2m1fn.
  • Host examples pack FP4 values into byte storage only as an interoperability
    detail.
  • Codegen recognizes packed global FP4 storage when computing memory offsets.
  • Shared-memory copies use the SM120 padded FP4 layout.
  • Local fragments are not rewritten into user-visible _packed aliases.

This avoids leaking uint8 into the public GEMM dtype model while still
allowing the generated CUDA path to use the byte/nibble representation that
the hardware requires.
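
For reference, e2m1 encodes sign(1)/exponent(2)/mantissa(1), so FP4 has only sixteen representable values; this is why the host examples can unpack through a small lookup table. Below is a standalone decode sketch (independent of, and not copied from, the PR's actual helper; low-nibble-first unpacking is an assumption):

```python
# The 16 values representable by FP4 e2m1: codes 0-7 are non-negative,
# codes 8-15 are their negations (code 8 is -0.0).
FP4_E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def decode_fp4_e2m1(code):
    """Decode a 4-bit e2m1 code into its float value."""
    return FP4_E2M1_LUT[code & 0xF]

def unpack_fp4_byte(byte):
    """Unpack one packed byte into two floats, low nibble first
    (assumed packing order)."""
    return decode_fp4_e2m1(byte & 0xF), decode_fp4_e2m1(byte >> 4)
```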

SM120 FP4 LDSM

src/tl_templates/cuda/ldsm.h adds SM120 ptx_ldmatrix_b4x16_x{1,2,4}
helpers guarded for the target architecture. CUDA lowering selects these
helpers when the source fragment is explicitly float4_e2m1fn.

The Python layout side adds FP4-specific ldmatrix logical layouts and offset
handling. This is gated on float4_e2m1fn, so the existing int4/uint4
ldmatrix offset behavior stays on the previous path.

SM120 MMA Dispatch

The template dispatch adds SM120 cute::SM120_16x8x32_TN support for:

  • FP4 x FP4 -> FP32
  • FP8 e4m3 x FP4 -> FP32
  • FP4 x FP8 e4m3 -> FP32

The FP4 operand register payload adjustment is applied only to FP4 operands
before calling the CuTe atom. Mixed A8W4/W4A8 dispatch is selected from the
explicit A/B dtype pair, rather than inferred from a packed integer carrier.

Copy And Async Copy

The copy path distinguishes packed global FP4 offsets from padded shared FP4
offsets. Shared-to-fragment copy lowering routes SM120 FP4 through the new
b4x16_p64 ldmatrix path, while global-to-shared async copy keeps using the
existing cp.async lowering with FP4 padded-shared metadata enabled.

For FP4 global-to-shared async copy, the lowering emits 8-byte segments for
packed FP4 storage and carries the extra metadata needed to place those bytes
into the padded shared-memory layout.
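
The segment math is simple: 16 FP4 elements occupy 8 packed bytes. A toy sketch of the segmentation (the helper name and return shape are illustrative, not the PR's actual lowering):

```python
def fp4_async_copy_segments(num_elems):
    """Split a run of packed FP4 elements into 8-byte async-copy
    segments of 16 logical elements each. Returns a list of
    (byte_start, byte_len) pairs; placement into the padded shared
    layout is handled separately by the lowering."""
    assert num_elems % 16 == 0, "copy extent must be whole 16-element segments"
    return [(i // 2, 8) for i in range(0, num_elems, 16)]
```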

K Tile Validation

SM120 FP4/A8W4 MMA consumes K in instruction-sized m16n8k32 tiles. A
T.gemm block K that is not divisible by the selected instruction K tile
cannot be represented by the current lowering without dropping the leftover
K range.

This PR therefore rejects unsupported K tile choices up front. For example,
block_K=48 is invalid for this path because it would execute one K=32 MMA
tile and silently omit the remaining K=16 tail.
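
The check itself reduces to a divisibility test against the instruction K tile. A minimal sketch of the rule (the real validation lives in tilelang/cuda/op/gemm/gemm_mma.py and is integrated with the lowering; this only shows the arithmetic):

```python
def validate_block_k(block_k, inst_k=32):
    """Reject block_K values the SM120 FP4/A8W4 MMA path cannot cover.
    Returns the number of K=inst_k MMA tiles per block."""
    if block_k % inst_k != 0:
        raise ValueError(
            f"block_K={block_k} is not divisible by the MMA instruction "
            f"K tile {inst_k}; a K tail of {block_k % inst_k} would be "
            f"silently skipped")
    return block_k // inst_k
```

Under this rule, `validate_block_k(64)` yields two K=32 MMA tiles per block, while `validate_block_k(48)` raises instead of silently dropping the K=16 tail.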

Main Changes

CUDA Templates

| File | Change |
| --- | --- |
| src/tl_templates/cuda/ldsm.h | Add SM120 ptx_ldmatrix_b4x16_x{1,2,4} helpers with architecture guard. |
| src/tl_templates/cuda/instruction/mma.h | Add SM120 cute::SM120_16x8x32_TN dispatch for FP4xFP4, FP8xFP4, and FP4xFP8 to FP32. |
| src/tl_templates/cuda/instruction/mma.h | Apply the FP4 operand register shift only for operands that are actually FP4. |
| src/tl_templates/cuda/gemm_mma.h | Register FP4 and mixed FP8/FP4 GEMM template dispatch. |
| src/tl_templates/cuda/cuda_fp4.h | Bridge TileLang FP4 template types to CuTe FP4 types while keeping existing packed helper types. |

CUDA Lowering

| File | Change |
| --- | --- |
| src/backend/cuda/codegen/codegen_cuda.cc | Select ptx_ldmatrix_b4x16_x{1,2,4} for explicit float4_e2m1fn ldmatrix loads. |
| src/backend/cuda/codegen/codegen_cuda.cc | Distinguish packed global FP4, padded shared FP4, and semantic local FP4 storage. |
| src/backend/cuda/codegen/codegen_cuda.cc | Keep local FP4 fragments under their declared names instead of introducing _packed aliases. |
| src/backend/cuda/codegen/codegen_cuda.h | Carry the codegen-side helpers/state needed by semantic FP4 local fragments. |
| src/backend/cuda/op/copy.cc | Extend shared-to-fragment LDSM lowering for SM120 FP4. |
| src/backend/cuda/op/copy.cc | Route SM120 FP4 async copy through the existing cp.async lowering with padded shared-copy handling enabled. |
| src/backend/cuda/op/copy_analysis.cc | Carry copy-analysis metadata needed by the FP4 padded shared-copy path. |
| src/transform/lower_ptx_async_copy.cc | Emit 8-byte FP4 global-to-shared async-copy segments for padded shared storage. |
| src/transform/ptx_async_copy_injector.h | Preserve the metadata/flag used by FP4 padded shared-copy injection. |

Python Lowering

| File | Change |
| --- | --- |
| tilelang/cuda/intrinsics/layout/utils.py | Add FP4-specific ldmatrix offset handling, gated on float4_e2m1fn. |
| tilelang/cuda/intrinsics/layout/mma_layout.py | Add FP4 ldmatrix logical layouts. |
| tilelang/cuda/intrinsics/macro/mma_macro_generator.py | Use SM120 FP4 m16n8k32 MMA granularity. |
| tilelang/cuda/op/gemm/gemm_mma.py | Validate mixed A/B dtypes explicitly for A8W4 and W4A8. |
| tilelang/cuda/op/gemm/gemm_mma.py | Reject block K values that are not divisible by the selected MMA instruction K tile. |

Examples

| File | Change |
| --- | --- |
| examples/gemm_fp4/example_gemm_fp4_sm120.py | Minimal SM120 FP4 GEMM example with numerical check. |
| examples/gemm_fp4/example_gemm_a8w4_sm120.py | Minimal SM120 A8W4 GEMM example with numerical check. |

Changes Added After Initial Review

| File | Change |
| --- | --- |
| src/backend/cuda/codegen/codegen_cuda.cc | Handle FP4 packed global vector load/store when the logical base offset is odd or not proven even. These cases now lower to per-lane nibble load/store operations instead of unsafe vector reinterpret access. |
| tilelang/cuda/op/gemm/gemm_mma.py | Reject T.gemm K tiles that are not divisible by the selected MMA instruction K tile, preventing FP4/A8W4 cases such as block_K=48 from silently skipping the K tail. |

Why The Review Fixes Matter

FP4 global packed storage is byte-addressed, but logical FP4 elements are
nibble-addressed. A vector reinterpret load/store is only safe when the
logical base offset is known to be even. If the offset is odd, or if codegen
cannot prove that it is even, vectorized byte reinterpretation can read or
write the wrong nibble without producing a compilation error.
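
The safety rule reduces to a tiny predicate: vectorized byte reinterpretation is allowed only when the logical base offset is a constant proven even; odd or symbolic bases must take the per-lane nibble path. A schematic sketch, not the codegen's actual condition (symbolic offsets are modeled here as strings):

```python
def fp4_vector_access_is_safe(base_offset):
    """A packed FP4 vector reinterpret covers whole bytes only when the
    logical element base is even. A symbolic offset cannot be proven
    even here, so it must fall back to per-lane nibble load/store."""
    return isinstance(base_offset, int) and base_offset % 2 == 0
```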

Likewise, SM120 FP4/A8W4 MMA consumes K in fixed instruction-sized chunks.
Allowing a block_K such as 48 would make the generated kernel execute only
the representable K=32 portion and miss the K tail. The new validation turns
that silent numerical error into an explicit unsupported-shape error.

Validation

Local SM120 validation used an RTX 5090 / compute capability 12.0 environment.

Build and examples:

cmake --build build -j$(nproc)
PYTHONPATH=$PWD${PYTHONPATH:+:$PYTHONPATH} python examples/gemm_fp4/example_gemm_fp4_sm120.py
PYTHONPATH=$PWD${PYTHONPATH:+:$PYTHONPATH} python examples/gemm_fp4/example_gemm_a8w4_sm120.py

Observed numerical results:

example_gemm_fp4_sm120.py: max_abs_diff=0.000000
example_gemm_a8w4_sm120.py: max_abs_diff=0.000000, rel_err=0.000000

Generated CUDA was inspected for the expected SM120 FP4 markers:

tl::mma_sync<..., 16, 8, 32, false, true>
ptx_ldmatrix_b4x16_x4
FP4xFP4 and FP8xFP4 dtype dispatch

Focused tests:

PYTHONPATH=$PWD${PYTHONPATH:+:$PYTHONPATH} python -m pytest testing/python/kernel/test_tilelang_kernel_fp4_gemm.py -q
PYTHONPATH=$PWD${PYTHONPATH:+:$PYTHONPATH} python -m pytest testing/python/language/test_tilelang_language_copy.py -q -k 'fp4_odd_vector_start or copy_fp4'

Observed focused-test results:

testing/python/kernel/test_tilelang_kernel_fp4_gemm.py: 2 passed
testing/python/language/test_tilelang_language_copy.py -k 'fp4_odd_vector_start or copy_fp4': 2 passed

Notes And Non-Goals

  • Mixed A8W4/W4A8 dispatch is selected from explicit FP8/FP4 dtype pairs.
  • FP4 ldmatrix layout handling is gated on float4_e2m1fn.
  • Existing int4/uint4 ldmatrix offset behavior stays on the existing path.

Status

All set now!

coderabbitai Bot commented May 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR implements SM120 (CUDA 12.0+) FP4 (float4_e2m1fn) GEMM support across TileLang: examples and host unpacking, CUDA/TI codegen FP4 storage/indexing/vector/scalar handling, FP4-aware cp.async injection, b4x16 ldmatrix helpers, CuTe SM120 MMA dispatch for FP4/mixed operands, layout/macro generation changes, and GemmMMA integration.

Changes

SM120 FP4 GEMM Support

| Layer | File(s) | Summary |
| --- | --- | --- |
| Examples & Host Helpers | examples/gemm_fp4/... | Adds FP4 LUT constant, unpack_fp4_storage_to_float, require_sm120(), TileLang kernel generators (matmul_a8w4 / matmul_fp4), main() harnesses, deterministic/random inputs, zero-input checks, float32 reference comparisons, and error assertions. |
| TL Templates: LDSM & CUDA FP4 Types | src/tl_templates/cuda/ldsm.h, src/tl_templates/cuda/cuda_fp4.h | Adds SM120-only ptx_ldmatrix_b4x16_x1/x2/x4 helpers and expands cuda_fp4 compile-time guards and make_fp4_e2_64_t. |
| MMA Dispatch & Instruction Support | src/tl_templates/cuda/gemm_mma.h, src/tl_templates/cuda/instruction/mma.h | Maps fp4_e2_t to CuTe float_e2m1_t, registers SM120 16x8x32 TN dispatchers for FP4×FP4 and mixed FP8/FP4, and updates tl::mma_sync to left-shift FP4 operands before dispatcher invocation. |
| CUDA Codegen: Buffer / Vector / Scalar Access | src/backend/cuda/codegen/codegen_cuda.cc, src/backend/cuda/codegen/codegen_cuda.h | Centralizes FP4 storage classification, adds IsFp4* helpers and GetFp4PaddedSharedIndex, applies padded-shared index remapping and packed-byte divisor logic, and implements FP4-aware scalar/vector load-store codegen and cp.async/ldmatrix emission paths. |
| PTX Async Injector & FP4-padded cp.async | src/transform/lower_ptx_async_copy.cc, src/transform/ptx_async_copy_injector.h | Introduces fp4_padded_shared_copy flag and FP4-padded cp.async specialization that splits transfers into 16-FP4-element segments with padded index remapping; forwards flag through InjectPTXAsyncCopy/PTXAsyncCopyInjector. |
| Copy Lowering & LDSM Geometry | src/backend/cuda/op/copy.cc | Threads FP4 padded mode into Copy lowering, gates FP4 ldmatrix lowering to SM120 and non-transposed paths, computes elems_per_reg/elems_per_inst for 4-bit types, and updates vectorization, access_ptr extents, local loads, and loop unroll trip counts. |
| Copy Eligibility Analysis | src/backend/cuda/op/copy_analysis.cc | Adds FP4-specific gating: CheckLDSMCopy requires SM120 and exact src/dst dtype match for FP4; CheckSTSMCopy rejects STSM copies when either side is FP4. |
| Macro Generation & Layout Utilities | tilelang/cuda/intrinsics/macro/mma_macro_generator.py, tilelang/cuda/intrinsics/layout/*.py, tilelang/cuda/intrinsics/layout/utils.py | Special-cases float4_e2m1fn to k_dim=32, routes 4-/8-bit types through shared_16x32→mma_32x16 transforms, computes FP4-dependent access extents (4*num), and adds FP4-specific layout mapping helpers plus get_ldmatrix_offset support. |
| GemmMMA Integration | tilelang/cuda/op/gemm/gemm_mma.py | Adds FP8/FP4 dtype predicates, _validate_mma_dtypes() to enforce allowed mixed operand pairs (FP8+FP4 or identical), and allocates local fragments per operand dtype during lowering. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • LJC00118
  • LeiWang1999
  • SiriusNEO

Poem

🐰 I nibble nibbles, pack them small,
SM120 wakes — kernels call.
Padded rows and cp.async chime,
TileLang hops through tiled time.
Rabbits cheer: GEMM runs fine.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 14.10%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Title check | ✅ Passed | The title clearly and specifically summarizes the main feature: adding SM120 support for FP4 (T.float4_e2m1fn) GEMM operations, which aligns with the substantial changes across CUDA templates, lowering, Python bindings, and examples. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |



github-actions Bot commented May 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai (Contributor) Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
src/tl_templates/cuda/cuda_fp4.h (1)

166-187: ⚡ Quick win

Verify register allocation for fp4_e2_t values[64] in device code.

The 64-element local array is constant-indexed throughout (values[0] .. values[63]), so nvcc at -O2+ should scalar-replace it into registers. However, unlike the explicitly-parameterized make_fp4_e2_32_t, which guarantees register-only arguments, register spilling to local memory is possible at lower optimisation levels or with larger surrounding register pressure. Consider adding a __forceinline__ annotation to maximise inlining and scalar replacement at call sites.

Proposed annotation
-template <typename... Args>
-TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
+template <typename... Args>
+TL_DEVICE __forceinline__ fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/cuda/cuda_fp4.h` around lines 166 - 187, The local array
fp4_e2_t values[64] in make_fp4_e2_64_t may be spilled under some compile
conditions; annotate the function to force inlining (e.g., add a
__forceinline__/always-inline device inline attribute to make_fp4_e2_64_t) so
nvcc can scalar-replace values[0]..values[63] into registers and inline the
make_fp4_e2_32_t calls; update the function declaration for make_fp4_e2_64_t
accordingly (keeping fp4_e2_t values[64] and the existing make_fp4_e2_32_t
usages unchanged).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/backend/cuda/codegen/codegen_cuda.cc`:
- Around line 1973-2003: The FP4 padded shared-memory vector path
(IsFp4PaddedSharedStorage + code using GetFp4PaddedSharedIndex and the
byte_offset lambda when constructing the reinterpret cast for t.lanes()) can
incorrectly span the padded 16-element row boundary; add a guard or split logic:
either assert the logical base alignment (e.g., Ensure base % 16 == 0 for the
requested load/store) or detect when the access crosses a 16-element row by
computing the start and end logical indices (base + offset and base + offset +
t.lanes()-1) and comparing their 16-element row indices (truncdiv(..., 16)); if
it crosses, split the operation into two row-aligned fragments (like the
existing t.lanes()==32 two-fragment approach) and merge them, otherwise keep the
current single contiguous byte reinterpretation; apply the same fix to the other
similar blocks identified (around the other ranges mentioned).
- Around line 4428-4444: The allocator treats only scope == "local" as the path
that emits local backing arrays but FP4 fragments use the semantic storage name
"local.fragment", so allocations for these still hit the unsupported-scope
branch; update the scope checks used around is_int4_scalar_local, the FP4
alignas(16) branch, and the place that prints/omits the storage scope to treat
"local.fragment" as equivalent to "local" (either normalize scope to "local"
earlier or change conditions from scope == "local" to (scope == "local" || scope
== "local.fragment")), ensuring PrintStorageScope/PrintType and the
backing-array emission path handle FP4 fragments the same as regular local
allocations (references: is_int4_scalar_local, op->dtype.is_float4_e2m1fn(),
PrintStorageScope, PrintType, and the "local.fragment" semantic storage).

In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py`:
- Around line 121-124: The FP4 fast-path in mma_macro_generator.py sets
self.k_dim = 32 without respecting self.chunk, causing micro_size_k to exceed
chunk when chunk < 32; update the FP4 branch in the initializer (the block
setting self.k_dim) to clamp k_dim by self.chunk (e.g., self.k_dim = min(32,
self.chunk)) and add the same clamp/guard in the subclass override (the code
around lines 873–877) so both places respect chunk; optionally emit a clear
ValueError or assertion if chunk < required minimum to fail early with a helpful
message referencing the dtype and chunk size.

---

Nitpick comments:
In `@src/tl_templates/cuda/cuda_fp4.h`:
- Around line 166-187: The local array fp4_e2_t values[64] in make_fp4_e2_64_t
may be spilled under some compile conditions; annotate the function to force
inlining (e.g., add a __forceinline__/always-inline device inline attribute to
make_fp4_e2_64_t) so nvcc can scalar-replace values[0]..values[63] into
registers and inline the make_fp4_e2_32_t calls; update the function declaration
for make_fp4_e2_64_t accordingly (keeping fp4_e2_t values[64] and the existing
make_fp4_e2_32_t usages unchanged).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a09f3145-ce2d-4b0d-bb75-d916a099b2be

📥 Commits

Reviewing files that changed from the base of the PR and between a797e51 and 140f774.

📒 Files selected for processing (16)
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/backend/cuda/codegen/codegen_cuda.cc
  • src/backend/cuda/codegen/codegen_cuda.h
  • src/backend/cuda/op/copy.cc
  • src/backend/cuda/op/copy_analysis.cc
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/ldsm.h
  • src/transform/lower_ptx_async_copy.cc
  • src/transform/ptx_async_copy_injector.h
  • tilelang/cuda/intrinsics/layout/mma_layout.py
  • tilelang/cuda/intrinsics/layout/utils.py
  • tilelang/cuda/intrinsics/macro/mma_macro_generator.py
  • tilelang/cuda/op/gemm/gemm_mma.py

@TerminusAkivili TerminusAkivili force-pushed the sm120-fp4-a8w4-clean-pr branch 3 times, most recently from 3e5823d to 7f254a9 Compare May 8, 2026 16:39
@TerminusAkivili TerminusAkivili changed the title [feature][Blackwell] Add SM120 FP4 and A8W4 GEMM support [feature][Blackwell] Add SM120 float4_e2m1fn FP4 GEMM support. May 8, 2026
@TerminusAkivili TerminusAkivili changed the title [feature][Blackwell] Add SM120 float4_e2m1fn FP4 GEMM support. [Feature][Blackwell] Add SM120 T.float4_e2m1fn FP4 GEMM support. May 11, 2026
@TerminusAkivili TerminusAkivili marked this pull request as draft May 11, 2026 15:26
@TerminusAkivili TerminusAkivili marked this pull request as ready for review May 11, 2026 16:06

TerminusAkivili commented May 11, 2026

Hi @LeiWang1999, no rush at all. Feel free to check it whenever it's convenient for you. I'd love your feedback. Thank you!

Lower FP4 packed vector load/store with odd or symbolic bases to per-lane nibble operations to avoid silent miscompiles.

Reject T.gemm K tiles that are not divisible by the MMA instruction K tile so FP4/A8W4 block_K tails cannot be silently skipped.
@TerminusAkivili TerminusAkivili force-pushed the sm120-fp4-a8w4-clean-pr branch from cb5bf3d to 795cb39 Compare May 12, 2026 07:46